!! Copyright (C) 2009,2010,2011,2012  Marco Restelli
!!
!! This file is part of:
!!   LDGH -- Local Hybridizable Discontinuous Galerkin toolkit
!!
!! LDGH is free software: you can redistribute it and/or modify it
!! under the terms of the GNU General Public License as published by
!! the Free Software Foundation, either version 3 of the License, or
!! (at your option) any later version.
!!
!! LDGH is distributed in the hope that it will be useful, but WITHOUT
!! ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
!! or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public
!! License for more details.
!!
!! You should have received a copy of the GNU General Public License
!! along with LDGH. If not, see <http://www.gnu.org/licenses/>.
!!
!! author: Marco Restelli                   <marco.restelli@gmail.com>


!>\brief
!!
!! Simplified interface to MUMPS.
!!
!! \n
!!
!! This module provides a simplified interface to the MUMPS solver,
!! according to the general layout given in \c mod_linsolver_base.
!!
!! \note The type c_mumpspb collects the structures used by MUMPS to
!! define a linear system. This means that each variable of a type
!! extended from \c c_mumpspb can be used to represent a linear system
!! independently from other linear systems. The MUMPS internal data
!! for each linear system are initialized with a call to \c factor
!! (unless \c phase is present and equal to <tt>"factorization"</tt>,
!! which requires a previous call for the initialization) and
!! deallocated with a call to \c clean.
!!
!! \section dist_mat Distributed matrix format
!!
!! MUMPS supports both centralized and distributed matrix format,
!! depending on the value of <tt>icntl(18)</tt>. This module
!! supports both formats, and for the distributed matrix format the
!! local to global indexing is specified in the field \c gij (this
!! field is not used when using the centralized matrix format). Such
!! mapping allows repetition in the sense that multiple local indexes
!! can be mapped to a unique global index. In this case, the
!! corresponding entries in the rhs are summed, while the
!! corresponding entry in the solution is repeated for all the
!! repeated local indexes.
!!
!! \warning MUMPS indexes start from 1, while indexes in \c
!! mod_sparse start from 0. This means that a translation must be
!! made when passing from the \c t_tri format to the MUMPS format.
!!
!! \note To solve a different linear system on each processor, set the
!! communicator so that each processor is the unique member of a
!! separate communicator.
!!
!! \todo MUMPS supports many solver configurations; however
!! here we consider only the case of parallel resolution of the
!! system with centralized matrix storage.
!<----------------------------------------------------------------------
module mod_mumpsintf

!-----------------------------------------------------------------------

 use mod_utils, only: &
   t_realtime, my_second

 use mod_messages, only: &
   mod_messages_initialized, &
   error,   &
   warning, &
   info

 use mod_kinds, only: &
   mod_kinds_initialized, &
   wp

 use mod_sparse, only: &
   mod_sparse_initialized, &
   ! sparse types
   t_col, t_tri,&
   col2tri,     &
   clear

 use mod_mpi_utils, only: &
   mod_mpi_utils_initialized, &
   mpi_reduce, wp_mpi, mpi_sum, &
   mpi_bcast

 use mod_state_vars, only: &
   mod_state_vars_initialized, &
   c_stv

 use mod_output_control, only: &
   mod_output_control_initialized, &
   elapsed_format, base_name

 use mod_linsolver_base, only: &
   mod_linsolver_base_initialized, &
   c_linpb

!-----------------------------------------------------------------------
 
 implicit none

!-----------------------------------------------------------------------

! Module interface

 public :: &
   mod_mumpsintf_constructor, &
   mod_mumpsintf_destructor,  &
   mod_mumpsintf_initialized, &
   c_mumpspb

 private

!-----------------------------------------------------------------------

 include 'dmumps_struc.h'

 ! MUMPS double precision
 integer, parameter :: dp = kind(1.0d0)

 !> Linear MUMPS solver problem
 !!
 !! This type describes a linear system from the MUMPS viewpoint.
 !! Extensions must provide \c xassign to copy the MUMPS solution into
 !! a \c c_stv state variable; \c factor, \c solve and \c clean are
 !! implemented here.
 type, extends(c_linpb), abstract :: c_mumpspb
  logical :: distributed !< see <tt>icntl(18)</tt>
  !> pivot ordering option: <tt>icntl(7)</tt>
  !! <table>
  !!  <tr> <th>value</th><th>method</th></tr>
  !!  <tr> <td align="center"> -1 </td><td align="center"> MUMPS'
  !!  default </td>
  !!  <tr> <td align="center"> 0 </td><td align="center"> AMD </td>
  !!  <tr> <td align="center"> 2 </td><td align="center"> AMF </td>
  !!  <tr> <td align="center"> 3 </td><td align="center"> SCOTCH </td>
  !!  <tr> <td align="center"> 4 </td><td align="center"> PORD </td>
  !!  <tr> <td align="center"> 5 </td><td align="center"> METIS </td>
  !!  <tr> <td align="center"> 6 </td><td align="center"> QAMD </td>
  !!  <tr> <td align="center"> 7 </td><td align="center"> MUMPS'
  !!  choice </td>
  !! </table>
  !! Negative values leave MUMPS' own default untouched (see
  !! \c mumps_factor).
  integer :: poo = -1
  !> global matrix size; copied into <tt>mumps_par%n</tt> in the
  !! distributed case, while in the centralized case the size is taken
  !! directly from the matrix \c m
  integer :: gn
  !> system matrix: this is the global matrix if
  !! <tt>distributed.eqv..true.</tt>, otherwise the local one. In both
  !! cases, the indexes are 0-based, as always with \c t_col objects.
  !! If the matrix is distributed, the field \c gij is used to map
  !! local nodes to the global ones.
  type(t_col), pointer :: m
  real(wp), pointer :: rhs(:) !< local right-hand side
  !> local to global map, 0-based (as in \c mod_sparse)
  integer, pointer :: gij(:)
  !> solve \f$A^Tx=b\f$
  logical :: transposed_mat = .false.
  !> MPI communicator (copied into <tt>mumps_par%comm</tt>)
  integer :: mpi_comm
  !> local to global map, 1-based (<tt>l2g_map = gij+1</tt>)
  integer, allocatable, private :: l2g_map(:)
  !> MUMPS internal system representation (see <tt>dmumps_struc.h</tt>)
  type(dmumps_struc), private :: mumps_par
  !> internal consistency check: .true. after the first successful call
  !! to \c factor, reset to .false. by \c clean
  logical, private :: sys_set = .false.
 contains
  procedure, pass(s) :: factor => mumps_factor
  procedure, pass(s) :: solve  => mumps_solve
  procedure, pass(s) :: clean  => mumps_clean
  procedure(i_xassign), deferred, pass(s) :: xassign
 end type c_mumpspb

 !> Convert the MUMPS solution into a \c c_stv object
 !!
 !! Deferred binding of \c c_mumpspb: concrete extensions copy the raw
 !! solution vector \c mumps_x (already converted to working precision
 !! \c wp by the caller, see \c mumps_solve) into the state variable
 !! \c x.
 abstract interface
  pure subroutine i_xassign(x,s,mumps_x)
   import :: wp, c_stv, c_mumpspb
   implicit none
   real(wp),         intent(in) :: mumps_x(:)
   class(c_mumpspb), intent(inout) :: s
   class(c_stv),     intent(inout) :: x
  end subroutine i_xassign
 end interface

 real(t_realtime) :: t0, t1
 character(len=1000) :: message
 logical, protected :: &
   mod_mumpsintf_initialized = .false.
 character(len=*), parameter :: &
   this_mod_name = 'mod_mumpsintf'

!-----------------------------------------------------------------------

contains

!-----------------------------------------------------------------------

 !> Module constructor: check that all the prerequisite modules are
 !! initialized, warn on repeated initialization and set the module
 !! flag.
 subroutine mod_mumpsintf_constructor()
  character(len=*), parameter :: &
    this_sub_name = 'constructor'

   !Consistency checks ---------------------------
   ! all the modules this one depends upon must be set up first
   if( .not.( mod_messages_initialized        .and. &
              mod_kinds_initialized           .and. &
              mod_sparse_initialized          .and. &
              mod_mpi_utils_initialized       .and. &
              mod_output_control_initialized  .and. &
              mod_linsolver_base_initialized ) ) then
     call error(this_sub_name,this_mod_name, &
                'Not all the required modules are initialized.')
   endif
   ! a second initialization is tolerated, but reported
   if(mod_mumpsintf_initialized) then
     call warning(this_sub_name,this_mod_name, &
                  'Module is already initialized.')
   endif
   !----------------------------------------------

   mod_mumpsintf_initialized = .true.
 end subroutine mod_mumpsintf_constructor

!-----------------------------------------------------------------------
 
 !> Module destructor: it is an error to finalize a module which was
 !! never initialized.
 subroutine mod_mumpsintf_destructor()
  character(len=*), parameter :: &
    this_sub_name = 'destructor'

   !Consistency checks ---------------------------
   if(.not.mod_mumpsintf_initialized) then
     call error(this_sub_name,this_mod_name, &
                'This module is not initialized.')
   endif
   !----------------------------------------------

   mod_mumpsintf_initialized = .false.
 end subroutine mod_mumpsintf_destructor

!-----------------------------------------------------------------------

 !> Factor matrix \c m
 !!
 !! Behaviour depends on \c phase:
 !! <ul>
 !!  <li> absent: MUMPS initialization (first call only), analysis and
 !!  factorization (job 4);
 !!  <li> <tt>'analysis'</tt>: analysis only (job 1);
 !!  <li> <tt>'factorization'</tt>: update the matrix coefficients and
 !!  redo only the numerical factorization (job 2), reusing the
 !!  symbolic analysis of a previous call.
 !! </ul>
 subroutine mumps_factor(s,phase)
  class(c_mumpspb), intent(inout) :: s
  character(len=*), intent(in), optional :: phase
 
  type(t_tri) :: m_tri ! used to recover the MUMPS format
  ! debugging helper: when .true. the matrix is dumped to file and the
  ! run is aborted (see complete_matrix_dump)
  logical, parameter :: write_mat_and_stop = .false.
  character(len=*), parameter :: dump_file = '.mumps-mat-dump'
  logical :: write_and_stop
  integer :: ierr
  character(len=*), parameter :: &
    this_sub_name = 'mumps_factor'
 
   write_and_stop = .false.

   ! Check whether only factorization is required
   factorization_if: if(present(phase)) then
     if(trim(phase).eq.'factorization') then

       if(.not.s%sys_set) call error(this_sub_name,this_mod_name, &
          'Trying to factorize a matrix which is not yet defined.')

       ! Update the matrix coefficients (the sparsity pattern, and
       ! hence the analysis, is assumed unchanged)
       if(s%distributed) then ! all procs
         m_tri = col2tri(s%m)
         s%mumps_par%a_loc = real(m_tri%tx,dp)
         call clear(m_tri)
       else ! only root proc
         if(s%mumps_par%myid.eq.0) then
           ! bug fix: m_tri must be built here as well; it was
           ! previously read without ever being assigned
           m_tri = col2tri(s%m)
           s%mumps_par%a   = real(m_tri%tx,dp)
           call clear(m_tri)
         endif
       endif

       if(write_mat_and_stop) write_and_stop = .true.

       t0 = my_second()
       s%mumps_par%job = 2 ! factorization
       call dmumps(s%mumps_par)
       t1 = my_second()
       write(message,elapsed_format)                                   &
               'Completed MUMPS factorization: elapsed time ',t1-t0,'s.'
       call info(this_mod_name,this_sub_name,message)

       if(write_and_stop) call complete_matrix_dump()
       return
     endif
   endif factorization_if

   if(s%sys_set) call warning(this_sub_name,this_mod_name, &
       'For this matrix an analysis has already been done!')

   ! Set the mumps_par field
   s%mumps_par%comm = s%mpi_comm

   ! Initialize an instance of the package (only once per system)
   if(.not.s%sys_set) then
     s%mumps_par%job = -1 ! initialization
     s%mumps_par%sym =  0 ! unsymmetric
     s%mumps_par%par =  1 ! host participates in computations
     t0 = my_second()
     call dmumps(s%mumps_par)
     t1 = my_second()
     write(message,elapsed_format)                                   &
             'Completed MUMPS initialization: elapsed time ',t1-t0,'s.'
     call info(this_mod_name,this_sub_name,message)
   endif

   ! Set output to stdout: typically unit 6
   s%mumps_par%icntl(1) = -6
   s%mumps_par%icntl(2) = -6
   s%mumps_par%icntl(3) = -6

   ! Set the pivot order option (negative s%poo keeps MUMPS' default)
   if(s%poo.ge.0) s%mumps_par%icntl(7) = s%poo

   ! according to the MUMPS documentation, write_problem must be set
   ! before the analysis phase
   if(write_mat_and_stop) &
         s%mumps_par%write_problem = trim(base_name)//dump_file

   distributed_if: if(s%distributed) then

     ! Set the module variable (1-based copy of the 0-based gij map)
     if(.not.s%sys_set) allocate(s%l2g_map(size(s%gij)))
     s%l2g_map = s%gij+1

     m_tri = col2tri(s%m)

     s%mumps_par%n       = s%gn
     s%mumps_par%nz_loc  = m_tri%nz
     if(.not.s%sys_set) allocate(s%mumps_par%irn_loc(m_tri%nz))
     s%mumps_par%irn_loc = s%l2g_map(m_tri%ti+1)
     if(.not.s%sys_set) allocate(s%mumps_par%jcn_loc(m_tri%nz))
     s%mumps_par%jcn_loc = s%l2g_map(m_tri%tj+1)
     ! the real kind can be different
     if(.not.s%sys_set) allocate(s%mumps_par%a_loc  (m_tri%nz))
     s%mumps_par%a_loc   = real(m_tri%tx,dp)

     call clear(m_tri)

     s%mumps_par%icntl(18) = 3 ! distributed matrix
   else
     ! the matrix is defined only on the master process
     if(s%mumps_par%myid.eq.0) then

       m_tri = col2tri(s%m)

       s%mumps_par%n  = m_tri%n
       s%mumps_par%nz = m_tri%nz
       if(.not.s%sys_set) allocate(s%mumps_par%irn(m_tri%nz))
       s%mumps_par%irn = m_tri%ti+1
       if(.not.s%sys_set) allocate(s%mumps_par%jcn(m_tri%nz))
       s%mumps_par%jcn = m_tri%tj+1
       ! the real kind can be different
       if(.not.s%sys_set) allocate(s%mumps_par%a(m_tri%nz))
       s%mumps_par%a = real(m_tri%tx,dp)

       call clear(m_tri)

     endif
     s%mumps_par%icntl(18) = 0 ! centralized matrix
   endif distributed_if

   s%mumps_par%icntl(5) = 0 ! assembled format

   if(present(phase)) then
     ! phase = 'factorization' has already been handled above
     if(trim(phase).eq.'analysis') then
       s%mumps_par%job = 1 ! analysis
     else
     call error(this_sub_name,this_mod_name, &
            'Unknown phase "'//trim(phase)//'".')
     endif
   else
     if(write_mat_and_stop) write_and_stop = .true.

     s%mumps_par%job = 4 ! analysis and factorization
   endif
   t0 = my_second()
   call dmumps(s%mumps_par)
   t1 = my_second()
   write(message,elapsed_format) 'Completed MUMPS analysis/analysis'// &
                          ' and factorization: elapsed time ',t1-t0,'s.'
   call info(this_mod_name,this_sub_name,message)

   s%sys_set = .true.

   ! don't stop if phase is "analysis"
   if(write_and_stop) call complete_matrix_dump()

 contains

  !> Abort after a matrix dump, synchronizing the processes first
  subroutine complete_matrix_dump()
    call mpi_barrier(s%mpi_comm,ierr)
    call error(this_sub_name,this_mod_name,                        &
      'Matrix dumped in file "' // trim(s%mumps_par%write_problem) &
      // '", exiting.')
  end subroutine complete_matrix_dump

 end subroutine mumps_factor
 
!-----------------------------------------------------------------------
 
 !> Solve the linear system
 !!
 !! When <tt>distributed.eqv..true.</tt> both the right hand side and
 !! the solution are treated as distributed vectors. This means that:
 !! <ul>
 !!  <li> The solution is first centralized on the master process and
 !!  then distributed to the remaining processes, so that each process
 !!  gets the solution components corresponding to the local column
 !!  indexes in the matrix \c m. This is motivated by the fact that
 !!  MUMPS can not perform error analysis and iterative refinement if
 !!  the solution is not centralized.
 !!  <li> The right hand side must be collected on the master process,
 !!  since MUMPS doesn't handle a distributed rhs.
 !! </ul>
 !<
 subroutine mumps_solve(x,s)
  class(c_mumpspb), intent(inout) :: s
  class(c_stv),     intent(inout) :: x

  integer :: i, ierr
  real(wp), allocatable :: recv_buff(:), send_buff(:)
  character(len=*), parameter :: &
    this_sub_name = 'mumps_solve'

   ! a call to factor must precede any solve
   if(.not.s%sys_set) call error(this_sub_name,this_mod_name, &
   'Trying to solve a system while the matrix is not defined.')

   ! set up the rhs: MUMPS wants a centralized rhs on the root process
   if(s%distributed) then
     allocate(recv_buff(s%mumps_par%n)) ! global size
     allocate(send_buff(s%mumps_par%n)); send_buff = 0.0_wp

     ! careful: l2g_map can have repetitions, so the local entries
     ! mapped to the same global index are accumulated
     ! NOTE(review): send_buff is already zeroed at allocation above;
     ! the following assignment is redundant but harmless
     send_buff = 0.0_wp
     do i=1,size(s%rhs)
       send_buff(s%l2g_map(i)) = send_buff(s%l2g_map(i)) + s%rhs(i)
     enddo
     ! sum the local contributions into the global rhs on rank 0
     call mpi_reduce( send_buff , recv_buff , s%mumps_par%n , &
                     wp_mpi , mpi_sum , 0 , s%mpi_comm , ierr )
     deallocate(send_buff) ! avoid allocating many large arrays
     if(s%mumps_par%myid.eq.0) then ! no need for rhs on other procs.
       allocate(s%mumps_par%rhs(s%mumps_par%n)) ! global size
       s%mumps_par%rhs = real(recv_buff,dp)
     endif
     deallocate(recv_buff)
   else
     if(s%mumps_par%myid.eq.0) then ! no need for rhs on other procs.
       allocate(s%mumps_par%rhs(s%mumps_par%n))
       s%mumps_par%rhs = real(s%rhs,dp)
     endif
   endif

   s%mumps_par%icntl(20) = 0 ! centralized dense rhs
   s%mumps_par%icntl(21) = 0 ! centralized dense solution
   s%mumps_par%icntl(10) = 3 ! iterative refinement

   ! icntl(9): 1 solves Ax=b, 2 solves A^Tx=b
   s%mumps_par%icntl(9) = 1
   if(s%transposed_mat) s%mumps_par%icntl(9) = 2
   
   t0 = my_second()
   s%mumps_par%job = 3 ! solve
   call dmumps(s%mumps_par)
   t1 = my_second()
   write(message,elapsed_format)                           &
           'Completed MUMPS solve: elapsed time ',t1-t0,'s.'
   call info(this_mod_name,this_sub_name,message)

   ! set up the solution (remember that MUMPS uses mumps_par%rhs to
   ! return the computed solution of the linear system)
   if(s%distributed) then
     ! The optimal solution would be a call to mpi_scatterv, where the
     ! root sends to each processor the information that it needs.
     ! However, this implies that the root keeps track of the rows
     ! seen by each process, which would be nontrivial. A simpler
     ! (likely less efficient) solution is adopted here, having the
     ! root process sending the complete solution to all the
     ! processes.
     allocate(send_buff(s%mumps_par%n))
     if(s%mumps_par%myid.eq.0) send_buff = real(s%mumps_par%rhs,wp)
     call mpi_bcast(send_buff, s%mumps_par%n, wp_mpi,0, s%mpi_comm,ierr)
     ! pick the needed components (repeated global indexes get the
     ! same, repeated value)
     call s%xassign(x,send_buff(s%l2g_map))
     deallocate(send_buff)
   else
     if(s%mumps_par%myid.eq.0) then
       call s%xassign(x,real(s%mumps_par%rhs,wp))
     endif
   endif
   ! the rhs/solution buffer is rebuilt at each solve
   if(s%mumps_par%myid.eq.0) deallocate(s%mumps_par%rhs)
 
 end subroutine mumps_solve
 
!-----------------------------------------------------------------------
 
 !> Release all the MUMPS resources associated with the system \c s
 !!
 !! Calling \c clean on a system which was never factored is a no-op.
 subroutine mumps_clean(s)
  class(c_mumpspb), intent(inout) :: s

  character(len=*), parameter :: &
    this_sub_name = 'mumps_clean'
 
   ! nothing to do unless a system has been defined
   if(.not.s%sys_set) return

   if(s%distributed) then
     ! every process owns a piece of the distributed matrix
     deallocate( s%mumps_par%irn_loc , &
                 s%mumps_par%jcn_loc , &
                 s%mumps_par%a_loc )
     deallocate( s%l2g_map )
   elseif(s%mumps_par%myid.eq.0) then
     ! the centralized matrix lives on the root process only
     deallocate( s%mumps_par%irn , &
                 s%mumps_par%jcn , &
                 s%mumps_par%a )
   endif

   ! let MUMPS release its internal data for this instance
   s%mumps_par%job = -2 ! clean-up
   call dmumps(s%mumps_par)

   s%sys_set = .false.

 end subroutine mumps_clean
 
!-----------------------------------------------------------------------

end module mod_mumpsintf

